import csv
import itertools

# Column order of the generated CSV; every row dict produced by
# ``iter_config_rows`` has exactly these keys.
CSV_FIELD_NAMES = [
    "run_name", "shield", "evaluation_shield", "punish_unsafe_orig_action",
    "punish_unsafe_orig_action_modifier", "randomize_starts", "map_type",
    "grid_world_map_name",
    "grid_world_obs_type", "learner_type", "learner_deep_network_model",
    "learner_anneal_eps_start",
    "learner_anneal_eps_finish", "learner_evaluation_epsilon", "seed",
    "grid_world_collision_penalty", "max_num_episodes",
]


def iter_config_rows(run_name, learner_type_models_obs, map_names, randomize_starts,
                     learner_anneal_eps, learner_evaluation_epsilons,
                     shields_punish_unsafe_orig_actions_collision_penalties,
                     num_runs, max_num_episodes=1000):
    """Yield one CSV row (as a dict) per configuration and repetition.

    The configurations are the cartesian product of the option lists, taken
    in argument order (the right-most list varies fastest). Each
    configuration is repeated ``num_runs`` times; the repetition number is
    written to the ``seed`` column.

    Args:
        run_name: Prefix used for the ``run_name`` column.
        learner_type_models_obs: Iterable of
            ``(learner_type, learner_deep_network_model, obs_type)`` tuples.
        map_names: Iterable of values for ``grid_world_map_name``.
        randomize_starts: Iterable of values for ``randomize_starts``.
        learner_anneal_eps: Iterable of ``(eps_start, eps_finish)`` tuples.
        learner_evaluation_epsilons: Iterable of values for
            ``learner_evaluation_epsilon``.
        shields_punish_unsafe_orig_actions_collision_penalties: Iterable of
            ``(shield, (punish_unsafe_orig_action, modifier),
            collision_penalty)`` tuples.
        num_runs: Number of repetitions of every configuration.
        max_num_episodes: Value written to the ``max_num_episodes`` column.

    Yields:
        dict: A mapping whose keys are exactly ``CSV_FIELD_NAMES``.
    """
    combinations = itertools.product(
        learner_type_models_obs, map_names, randomize_starts,
        learner_anneal_eps, learner_evaluation_epsilons,
        shields_punish_unsafe_orig_actions_collision_penalties)

    for run_type_idx, (
            (learner_type, learner_model, obs_type), map_name, random_start,
            (eps_anneal_start, eps_anneal_finish),
            learner_evaluation_epsilon,
            (shield, (punish_unsafe_action, unsafe_action_rew_modifier),
             collision_penalty)) in enumerate(combinations):

        for run_num_of_same_type in range(num_runs):
            # Index that is unique over every (configuration, repetition) pair.
            global_run_idx = run_type_idx * num_runs + run_num_of_same_type

            yield {
                # "<run_name>/<global idx>_<configuration idx>_<repetition idx>"
                "run_name": f"{run_name}/{global_run_idx}_{run_type_idx}_{run_num_of_same_type}",
                "shield": shield,
                "evaluation_shield": "none",
                "punish_unsafe_orig_action": punish_unsafe_action,
                "punish_unsafe_orig_action_modifier": unsafe_action_rew_modifier,
                "randomize_starts": random_start,
                "map_type": "GridWorld",
                "grid_world_map_name": map_name,
                "grid_world_obs_type": obs_type,
                "learner_type": learner_type,
                # A ``None`` model is written by csv.DictWriter as an empty cell.
                "learner_deep_network_model": learner_model,
                "learner_anneal_eps_start": eps_anneal_start,
                "learner_anneal_eps_finish": eps_anneal_finish,
                "learner_evaluation_epsilon": learner_evaluation_epsilon,
                # Repetitions of the same configuration differ only in seed.
                "seed": run_num_of_same_type,
                "grid_world_collision_penalty": collision_penalty,
                "max_num_episodes": max_num_episodes,
            }


if __name__ == '__main__':
    # Writes ../../parallel_configs/<run_name>.csv (relative to the current
    # working directory) with one row per experiment run.
    run_name = "InitialDNNSmall"

    learner_type_models_obs = [("Individual_Q", None, "FullObsDiscrete"), ("Individual_Deep_Q", "simple_mlp", "2DObs"),
                               #  ("Individual_Deep_Q", "simple_cnn", "2DObs")
                               ]
    map_names = ["Pentagon", "ISR", "SUNY", "MIT"]
    randomize_starts = [False]
    learner_anneal_eps = [(0.2, 0.05)]
    learner_evaluation_epsilons = [0, 0.05]
    num_runs = 3
    collision_penalties = [0, -25, -50, -75, -100]

    # The "centralized" shield entry uses one fixed collision penalty (-30)
    # and punishes unsafe original actions (modifier -10); the "none" entries
    # sweep over ``collision_penalties`` without that punishment.
    shields_punish_unsafe_orig_actions_collision_penalties = [("centralized", (True, -10), -30)] + [
        ("none", (False, 0), i) for i in collision_penalties]

    # newline="" is required by the csv module: without it, platforms that
    # translate "\n" (e.g. Windows) emit "\r\r\n" and show blank lines
    # between records.
    with open(f"../../parallel_configs/{run_name}.csv", "w", newline="", encoding="utf-8") as file:
        writer = csv.DictWriter(file, CSV_FIELD_NAMES)
        writer.writeheader()
        writer.writerows(iter_config_rows(
            run_name,
            learner_type_models_obs,
            map_names,
            randomize_starts,
            learner_anneal_eps,
            learner_evaluation_epsilons,
            shields_punish_unsafe_orig_actions_collision_penalties,
            num_runs,
            max_num_episodes=int(1e3),
        ))
